Propose to refactor output normalization in several transformers #11850
Draft: tolgacangoz wants to merge 12 commits into huggingface:main from tolgacangoz:transfer-shift_scale_norm-to-AdaLayerNorm
Conversation
Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`. This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism.
…anVACE transformers
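For illustration, here is a minimal sketch of the pattern this commit removes: a learned `scale_shift_table` combined with an fp32 layer norm in the forward pass. The class name and tensor shapes are assumptions made for the sketch; the real Wan modules differ in detail.

```python
import torch
import torch.nn as nn

class ManualOutputNorm(nn.Module):
    """Sketch of the old pattern: a standalone table plus a non-affine LayerNorm run in fp32."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
        self.norm_out = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # temb: (batch, dim) broadcast against the (1, 2, dim) table -> shift, scale: (batch, 1, dim)
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states.float()) * (1 + scale) + shift
        return hidden_states.type_as(temb)

# After the refactor, the same step collapses to a single adaptive-norm call, e.g.
#     hidden_states = self.norm_out(hidden_states, temb=temb)
# where `norm_out` is an AdaLayerNorm-style module that owns the projection,
# the LayerNorm, and the shift/scale application.
```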
Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script. This correction ensures that weights are loaded correctly for both standard and VACE transformer models.
… in Wan and WanVACE transformers
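In a conversion script of this kind, the change amounts to updating one entry in the key-rename mapping. The dictionary name and the helper below are assumptions for the sketch, not the actual script:

```python
# Hypothetical excerpt from a checkpoint-conversion script: original Wan parameter
# names on the left are rewritten to diffusers module paths on the right.
TRANSFORMER_KEYS_RENAME_DICT = {
    # Previously the modulation parameter was kept as a standalone table:
    # "head.modulation": "scale_shift_table",
    # With AdaLayerNorm owning the modulation, it now lives under the output norm:
    "head.modulation": "norm_out.linear",
}

def rename_key(key: str) -> str:
    """Rewrite an original checkpoint key using the first matching prefix."""
    for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
        if key.startswith(old):
            return key.replace(old, new, 1)
    return key
```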
Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the unified `AdaLayerNorm` module. This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility.
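Sketched at the module level, the commit touches two places in each transformer's `__init__`: the unified output norm and the `_no_split_modules` list. The constructor arguments shown for `AdaLayerNorm` follow recent diffusers versions but may differ; the surrounding class is a toy stand-in, not the real model.

```python
import torch.nn as nn
from diffusers.models.normalization import AdaLayerNorm  # import path may vary across versions

class ToyWanStyleTransformer(nn.Module):
    # Declared so model-parallel sharding keeps the output norm (its linear projection
    # and LayerNorm) on a single device; the real list also names the block classes.
    _no_split_modules = ["norm_out"]

    def __init__(self, dim: int = 1536):
        super().__init__()
        # One module now owns the temb projection, the LayerNorm, and the shift/scale step.
        self.norm_out = AdaLayerNorm(
            embedding_dim=dim,
            output_dim=2 * dim,            # projects temb into (shift, scale)
            norm_elementwise_affine=False,
            norm_eps=1e-6,
            chunk_dim=1,                   # split the projection along the feature dimension
        )
        self.proj_out = nn.Linear(dim, dim)
```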
Corrects the target key for `head.modulation` to `norm_out.linear.weight`. This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types.
Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary.
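A sketch of that fallback, assuming the bias key and the `2 * dim` shape implied by a shift/scale projection; the helper name is hypothetical:

```python
import torch

def ensure_norm_out_bias(state_dict: dict, dim: int) -> dict:
    """If the original checkpoint provides no bias for the output norm's linear
    projection, add a zero-initialized one so the converted state dict matches
    the AdaLayerNorm module's parameter set."""
    key = "norm_out.linear.bias"
    if key not in state_dict:
        state_dict[key] = torch.zeros(2 * dim)  # shift and scale, hence 2 * dim values
    return state_dict
```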
Force-pushed from dad0e68 to 65639d5
(I attempted to make the replacements, if you don't mind :)

This proposed PR will be activated when the SkyReels-V2 integration PR is merged into `main`. It replaces `FP32LayerNorm` with `AdaLayerNorm` in `WanTransformer3DModel`, `WanVACETransformer3DModel`, ..., to simplify the forward pass and enhance model parallelism compatibility.

Context: #11518 (comment)
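For readers unfamiliar with the module, the following is a minimal, self-contained sketch of the adaptive-normalization pattern that `AdaLayerNorm` encapsulates. It mirrors the shift/scale logic described in the commits above but is not the diffusers implementation; the class name `TinyAdaLayerNorm` and all shapes are illustrative.

```python
from typing import Optional

import torch
import torch.nn as nn

class TinyAdaLayerNorm(nn.Module):
    """Project a conditioning embedding into (shift, scale) and modulate a LayerNorm output."""

    def __init__(self, embedding_dim: int, output_dim: Optional[int] = None, eps: float = 1e-6):
        super().__init__()
        output_dim = output_dim or 2 * embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, elementwise_affine=False, eps=eps)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # temb: (batch, embedding_dim) -> (batch, 2 * dim) -> shift, scale: (batch, 1, dim)
        shift, scale = self.linear(self.silu(temb)).chunk(2, dim=1)
        shift, scale = shift[:, None, :], scale[:, None, :]
        return self.norm(x) * (1 + scale) + shift


# Example: x is a sequence of tokens, temb a per-sample timestep embedding.
x = torch.randn(2, 16, 128)      # (batch, tokens, dim)
temb = torch.randn(2, 128)       # (batch, dim)
out = TinyAdaLayerNorm(embedding_dim=128)(x, temb)
print(out.shape)                 # torch.Size([2, 16, 128])
```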